# -*-Python-*-

# Add task adapters to the end of each Transformer layer block. These adapters
# are described in "Simple, Scalable Adaptation for Neural Machine Translation"
# (Bapna et al., 2019).
# When fine-tuning, an additional layer norm + DenseReluDense + residual is
# appended to the end of each Transformer layer block, and only the parameters
# of these added layers are updated for a single task.
# To use task adapters correctly, a pre-trained model checkpoint path must also
# be supplied to tpu_estimator_model_fn.init_checkpoint. This .gin file assumes
# that this is done via an external flag.

import mesh_tensorflow.transformer.transformer_layers

# For Unitransformer models (e.g. language-model).
# Layers listed for each block, in order: self-attention, DenseReluDense, then
# the task adapter. The adapter entry is a ("task_adapter", layer) pair whose
# layer is another DenseReluDense referenced under the `task_adapter/` gin
# scope, so it can be configured separately from the unscoped DenseReluDense.
# NOTE(review): presumably the "task_adapter" string is used as the layer name
# (and so appears in its variable names) -- confirm in make_layer_stack.
transformer.make_layer_stack.layers = [
    @mesh_tensorflow.transformer.transformer_layers.SelfAttention,
    @mesh_tensorflow.transformer.transformer_layers.DenseReluDense,
    ("task_adapter", @task_adapter/mesh_tensorflow.transformer.transformer_layers.DenseReluDense),
]

# For Bitransformer models (two-stack sequence-to-sequence models).
# Encoder stack, bound under the `encoder/` gin scope. Layers listed for each
# block, in order: self-attention, DenseReluDense, then the task adapter (a
# DenseReluDense referenced under the `task_adapter/` scope and paired with
# the string "task_adapter").
encoder/transformer.make_layer_stack.layers = [
    @mesh_tensorflow.transformer.transformer_layers.SelfAttention,
    @mesh_tensorflow.transformer.transformer_layers.DenseReluDense,
    ("task_adapter", @task_adapter/mesh_tensorflow.transformer.transformer_layers.DenseReluDense),
]
# Decoder stack, bound under the `decoder/` gin scope. Layers listed for each
# block, in order: self-attention, encoder-decoder attention, DenseReluDense,
# then the task adapter (a DenseReluDense referenced under the `task_adapter/`
# scope and paired with the string "task_adapter").
decoder/transformer.make_layer_stack.layers = [
    @mesh_tensorflow.transformer.transformer_layers.SelfAttention,
    @mesh_tensorflow.transformer.transformer_layers.EncDecAttention,
    @mesh_tensorflow.transformer.transformer_layers.DenseReluDense,
    ("task_adapter", @task_adapter/mesh_tensorflow.transformer.transformer_layers.DenseReluDense),
]

# Only update the adapter parameters: the filter string is the same
# "task_adapter" name given to the adapter layers in the layer stacks of this
# file.
# NOTE(review): presumably utils.run matches this filter against variable
# names to select what is trained -- confirm the exact matching semantics.
utils.run.variable_filter = "task_adapter"
